import numpy as np
import matplotlib.pyplot as plt

##############################################################################################################################
                                             # Import data
##############################################################################################################################

#### KL_UCB_Transfer

# NOTE(review): this file name says "Sim1" while the variables below and the
# figure refer to simulation 3 -- confirm that this is the intended input file.
# np.load on a .npz returns an archive object that keeps the file open; the
# context manager closes it as soon as both arrays have been read into memory.
with np.load("regret_KL_UCB_Transfer_Sim1prior3.npz") as data:
    tsaveS3R1 = data["tsave"]     # shape (M,)
    RegretS3R1 = data["Regret"]   # shape (R, M)

# Unpacking also acts as a sanity check: it fails unless Regret is 2-D.
R, M = RegretS3R1.shape

# Column-wise statistics along axis 0 (the axis of length R).
mean_regretS3R1 = RegretS3R1.mean(axis=0)                     # shape (M,)
# Standard error of the mean: sample standard deviation (ddof=1) / sqrt(R).
sem_regretS3R1 = RegretS3R1.std(axis=0, ddof=1) / np.sqrt(R)




#### AST_UCB

# np.load on a .npz returns an archive object that keeps the file open; the
# context manager closes it as soon as both arrays have been read into memory.
with np.load("regret_AST_UCB_Sim3prior3.npz") as data:
    tsaveS3R2 = data["tsave"]     # shape (M,)
    RegretS3R2 = data["Regret"]   # shape (R, M)

# Unpacking also acts as a sanity check: it fails unless Regret is 2-D.
# Note that this rebinds the module-level R and M to this file's dimensions.
R, M = RegretS3R2.shape

# Column-wise statistics along axis 0 (the axis of length R).
mean_regretS3R2 = RegretS3R2.mean(axis=0)                     # shape (M,)
# Standard error of the mean: sample standard deviation (ddof=1) / sqrt(R).
sem_regretS3R2 = RegretS3R2.std(axis=0, ddof=1) / np.sqrt(R)

##############################################################################################################################
                                             # Plot Simulation 3
##############################################################################################################################

# Single small figure: every algorithm is drawn as its mean regret curve
# on top of a shaded band of +/- one standard error.
plt.figure(figsize=(5, 3))

# One row per algorithm:
# (time grid, mean regret, standard error, legend label, colour, line style)
curves = (
    (tsaveS3R1, mean_regretS3R1, sem_regretS3R1, "KL_UCB_Transfer", "b", '-'),
    (tsaveS3R2, mean_regretS3R2, sem_regretS3R2, "AST_UCB", "r", '--'),
)

for t_grid, centre, spread, name, colour, dash in curves:
    band_low = centre - spread
    band_high = centre + spread
    # Shaded band first, then the mean line over it.
    plt.fill_between(t_grid, band_low, band_high, alpha=0.3, color=colour)
    plt.plot(t_grid, centre, lw=1.5, label=name, color=colour, linestyle=dash)

# Axes cosmetics: logarithmic x axis, LaTeX-style labels, faint dashed grid
# on both major and minor ticks.
plt.xscale('log')
plt.xlabel('$T$')
plt.ylabel('$R_T$')
plt.legend()
plt.grid(True, which='both', ls='--', alpha=0.4)
plt.tight_layout()

# Write the PDF to disk, then display the figure.
plt.savefig("plot3ACML.pdf", format="pdf")
plt.show()